

<!DOCTYPE html>
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
<head>
  <meta charset="utf-8">
  
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  
  <title>TensorFlow Extensions &mdash; A Concise Handbook of TensorFlow 0.3 beta documentation</title>
  

  
  
  
  

  

  
  
    

  

  
  
    <link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
  

  

  
        <link rel="index" title="Index"
              href="../genindex.html"/>
        <link rel="search" title="Search" href="../search.html"/>
    <link rel="top" title="A Concise Handbook of TensorFlow 0.3 beta documentation" href="../index.html"/>
        <link rel="next" title="Appendix: Static TensorFlow" href="static.html"/>
        <link rel="prev" title="TensorFlow Models" href="models.html"/> 

  
  <script src="../_static/js/modernizr.min.js"></script>

</head>

<body class="wy-body-for-nav" role="document">

   
  <div class="wy-grid-for-nav">

    

    <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">

      


      
      <div class="wy-nav-content">
        <div class="rst-content">
          















          <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
           <div itemprop="articleBody">
            
  <div class="section" id="tensorflow">
<h1>TensorFlow Extensions<a class="headerlink" href="#tensorflow" title="Permalink to this headline">¶</a></h1>
<p>This chapter introduces some of the most commonly used TensorFlow extensions. Although these features are not strictly required, they make training and using models more convenient.</p>
<p>Prerequisites:</p>
<ul class="simple">
<li><a class="reference external" href="http://www.runoob.com/python3/python3-inputoutput.html">Python's serialization module Pickle</a> (optional)</li>
<li><a class="reference external" href="https://eastlakeside.gitbooks.io/interpy-zh/content/args_kwargs/Usage_kwargs.html">Python's special function parameter **kwargs</a> (optional)</li>
</ul>
<div class="section" id="checkpoint">
<h2>Checkpoint: Saving and Restoring Variables<a class="headerlink" href="#checkpoint" title="Permalink to this headline">¶</a></h2>
<p>Often we want to save the trained parameters (variables) once training has finished, so that the model and its parameters can be loaded wherever the model is needed and the trained model is available immediately. Your first idea may be to store <code class="docutils literal"><span class="pre">model.variables</span></code> with Python's serialization module <code class="docutils literal"><span class="pre">pickle</span></code>. Unfortunately, TensorFlow's variable type <code class="docutils literal"><span class="pre">ResourceVariable</span></code> cannot be serialized this way.</p>
<p>Fortunately, TensorFlow provides <a class="reference external" href="https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint">tf.train.Checkpoint</a>, a powerful class for saving and restoring variables. Its <code class="docutils literal"><span class="pre">save()</span></code> and <code class="docutils literal"><span class="pre">restore()</span></code> methods save and restore every TensorFlow object that contains Checkpointable State. Specifically, <code class="docutils literal"><span class="pre">tf.train.Optimizer</span></code> implementations, <code class="docutils literal"><span class="pre">tf.Variable</span></code>, <code class="docutils literal"><span class="pre">tf.keras.layers.Layer</span></code> implementations and <code class="docutils literal"><span class="pre">tf.keras.Model</span></code> implementations can all be saved. Usage is simple. First, declare a Checkpoint:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">checkpoint</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">train</span><span class="o">.</span><span class="n">Checkpoint</span><span class="p">(</span><span class="n">model</span><span class="o">=</span><span class="n">model</span><span class="p">)</span>
</pre></div>
</div>
<p>The initialization argument of <code class="docutils literal"><span class="pre">tf.train.Checkpoint()</span></code> is somewhat special: it is a <code class="docutils literal"><span class="pre">**kwargs</span></code>, that is, a series of key-value pairs in which the key names are arbitrary and the values are the objects to be saved. For example, to save a model instance <code class="docutils literal"><span class="pre">model</span></code> that inherits from <code class="docutils literal"><span class="pre">tf.keras.Model</span></code> and an optimizer <code class="docutils literal"><span class="pre">optimizer</span></code> that inherits from <code class="docutils literal"><span class="pre">tf.train.Optimizer</span></code>, we can write:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">checkpoint</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">train</span><span class="o">.</span><span class="n">Checkpoint</span><span class="p">(</span><span class="n">myAwesomeModel</span><span class="o">=</span><span class="n">model</span><span class="p">,</span> <span class="n">myAwesomeOptimizer</span><span class="o">=</span><span class="n">optimizer</span><span class="p">)</span>
</pre></div>
</div>
<p>Here <code class="docutils literal"><span class="pre">myAwesomeModel</span></code> is an arbitrary key name chosen for the model <code class="docutils literal"><span class="pre">model</span></code> to be saved. Note that the same key name must be used again when the variables are restored.</p>
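<p>To make the <code class="docutils literal"><span class="pre">**kwargs</span></code> mechanism concrete, the following is a minimal sketch in plain Python. <code class="docutils literal"><span class="pre">ToyCheckpoint</span></code> is a hypothetical stand-in written for this illustration, not part of TensorFlow, and the two dictionaries are placeholders for a real model and optimizer. It only shows that the key names chosen by the caller arrive as dictionary keys.</p>

```python
# Hypothetical stand-in for tf.train.Checkpoint, used only to show how
# **kwargs collects keyword arguments; it saves nothing to disk.
class ToyCheckpoint:
    def __init__(self, **kwargs):
        # kwargs is an ordinary dict: {key name: object to track}
        self.tracked = dict(kwargs)


model = {"w": 1.0}            # placeholder objects, not real TensorFlow objects
optimizer = {"lr": 0.001}

ckpt = ToyCheckpoint(myAwesomeModel=model, myAwesomeOptimizer=optimizer)
print(sorted(ckpt.tracked))   # ['myAwesomeModel', 'myAwesomeOptimizer']
```

<p>Because the key names are stored alongside the objects, a later restore can only match saved values to objects if the same names are supplied again.</p>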
<p>Next, when training is finished and the model needs to be saved, call:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">checkpoint</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">save_path_with_prefix</span><span class="p">)</span>
</pre></div>
</div>
<p>Here <code class="docutils literal"><span class="pre">save_path_with_prefix</span></code> is the directory plus the prefix of the saved files. For example, if we create a folder named save in the source directory and call <code class="docutils literal"><span class="pre">checkpoint.save('./save/model.ckpt')</span></code> once, three files named <code class="docutils literal"><span class="pre">checkpoint</span></code>, <code class="docutils literal"><span class="pre">model.ckpt-1.index</span></code> and <code class="docutils literal"><span class="pre">model.ckpt-1.data-00000-of-00001</span></code> appear in the save directory; together they record the variable information. <code class="docutils literal"><span class="pre">checkpoint.save()</span></code> can be called many times. Each call produces one .index file and one .data file, with the number increasing by one each time.</p>
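<p>The naming scheme described above can be sketched without TensorFlow. The helper <code class="docutils literal"><span class="pre">expected_checkpoint_files</span></code> below is hypothetical and simply reproduces the file names given in the text, assuming a single data shard (<code class="docutils literal"><span class="pre">00000-of-00001</span></code>) as in the example. It does not write any checkpoint data.</p>

```python
import os


def expected_checkpoint_files(prefix, save_count):
    """Hypothetical helper: list the file names that the text says appear
    after calling checkpoint.save(prefix) save_count times."""
    directory, base = os.path.split(prefix)
    names = ["checkpoint"]                      # bookkeeping file, written once
    for i in range(1, save_count + 1):
        names.append("%s-%d.index" % (base, i))
        names.append("%s-%d.data-00000-of-00001" % (base, i))
    return [os.path.join(directory, n) for n in names]


files = expected_checkpoint_files("./save/model.ckpt", 2)
for f in files:
    print(f)
```

<p>Two saves therefore leave five files in the directory: the shared <code class="docutils literal"><span class="pre">checkpoint</span></code> file plus one .index and one .data file per save.</p>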
<p>To reload the saved parameters into a model elsewhere, instantiate a checkpoint again with the same key name, then call its restore method, like this:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">model_to_be_restored</span> <span class="o">=</span> <span class="n">MyModel</span><span class="p">()</span>                                        <span class="c1"># the same model whose parameters are to be restored</span>
<span class="n">checkpoint</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">train</span><span class="o">.</span><span class="n">Checkpoint</span><span class="p">(</span><span class="n">myAwesomeModel</span><span class="o">=</span><span class="n">model_to_be_restored</span><span class="p">)</span>   <span class="c1"># keep the key name "myAwesomeModel"</span>
<span class="n">checkpoint</span><span class="o">.</span><span class="n">restore</span><span class="p">(</span><span class="n">save_path_with_prefix_and_index</span><span class="p">)</span>
</pre></div>
</div>
<p>This restores the model variables. <code class="docutils literal"><span class="pre">save_path_with_prefix_and_index</span></code> is the directory + prefix + number of a previously saved file. For example, <code class="docutils literal"><span class="pre">checkpoint.restore('./save/model.ckpt-1')</span></code> loads the files with prefix <code class="docutils literal"><span class="pre">model.ckpt</span></code> and number 1 to restore the model.</p>
<p>When several checkpoints have been saved, we usually want to load the most recent one. The helper function <code class="docutils literal"><span class="pre">tf.train.latest_checkpoint(save_path)</span></code> returns the file name of the latest checkpoint in a directory. For example, if the save directory holds the 10 saved files <code class="docutils literal"><span class="pre">model.ckpt-1.index</span></code> to <code class="docutils literal"><span class="pre">model.ckpt-10.index</span></code>, <code class="docutils literal"><span class="pre">tf.train.latest_checkpoint('./save')</span></code> returns <code class="docutils literal"><span class="pre">./save/model.ckpt-10</span></code>.</p>
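<p>As an illustration of what "latest" means here, the sketch below picks the highest-numbered prefix by scanning .index files with the standard library. This is an approximation written for this section: the helper name <code class="docutils literal"><span class="pre">latest_prefix_by_number</span></code> is hypothetical, and the real <code class="docutils literal"><span class="pre">tf.train.latest_checkpoint</span></code> consults the <code class="docutils literal"><span class="pre">checkpoint</span></code> bookkeeping file instead of scanning file names. The point is that the numbers must be compared as integers, so that 10 ranks above 9.</p>

```python
import os
import re
import tempfile


def latest_prefix_by_number(save_dir):
    """Hypothetical helper: return the highest-numbered checkpoint prefix
    found by scanning *.index files, or None if there are none."""
    pattern = re.compile(r"^(?P<prefix>.+-(?P<num>\d+))\.index$")
    best = None
    for name in os.listdir(save_dir):
        match = pattern.match(name)
        if match:
            # Compare by integer number first, so 10 beats 9.
            candidate = (int(match.group("num")), match.group("prefix"))
            if best is None or candidate > best:
                best = candidate
    return None if best is None else os.path.join(save_dir, best[1])


with tempfile.TemporaryDirectory() as save_dir:
    for i in range(1, 11):
        # Create empty placeholder files named like real index files.
        open(os.path.join(save_dir, "model.ckpt-%d.index" % i), "w").close()
    latest = latest_prefix_by_number(save_dir)
    print(os.path.basename(latest))   # model.ckpt-10
```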
<p>Overall, a typical code skeleton for saving and restoring variables looks like this:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="c1"># train.py: model training stage</span>

<span class="n">model</span> <span class="o">=</span> <span class="n">MyModel</span><span class="p">()</span>
<span class="n">checkpoint</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">train</span><span class="o">.</span><span class="n">Checkpoint</span><span class="p">(</span><span class="n">myModel</span><span class="o">=</span><span class="n">model</span><span class="p">)</span>     <span class="c1"># instantiate Checkpoint and track model (the optimizer can be added too if its parameters should be saved)</span>
<span class="c1"># model training code</span>
<span class="n">checkpoint</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="s1">&#39;./save/model.ckpt&#39;</span><span class="p">)</span>                <span class="c1"># save the parameters to file after training; saving at intervals during training is also possible</span>
</pre></div>
</div>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="c1"># test.py: model usage stage</span>

<span class="n">model</span> <span class="o">=</span> <span class="n">MyModel</span><span class="p">()</span>
<span class="n">checkpoint</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">train</span><span class="o">.</span><span class="n">Checkpoint</span><span class="p">(</span><span class="n">myModel</span><span class="o">=</span><span class="n">model</span><span class="p">)</span>             <span class="c1"># instantiate Checkpoint and set model as the object to restore</span>
<span class="n">checkpoint</span><span class="o">.</span><span class="n">restore</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">train</span><span class="o">.</span><span class="n">latest_checkpoint</span><span class="p">(</span><span class="s1">&#39;./save&#39;</span><span class="p">))</span>    <span class="c1"># restore the model parameters from file</span>
<span class="c1"># model usage code</span>
</pre></div>
</div>
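<p>Incidentally, compared with <code class="docutils literal"><span class="pre">tf.train.Saver</span></code>, which was common in earlier versions, the strength of <code class="docutils literal"><span class="pre">tf.train.Checkpoint</span></code> is that it supports "delayed" restoration of variables under Eager Execution. Specifically, if <code class="docutils literal"><span class="pre">checkpoint.restore()</span></code> is called before the variables in the model have been created, Checkpoint waits until they are created and then restores their values. Under Eager Execution, the layers of a model are initialized and their variables created only when the model is called for the first time (the benefit is that variable shapes can be inferred automatically from the shape of the input tensor instead of being specified by hand). This means that a freshly instantiated model contains no variables at all, so trying to restore variable values the old way is bound to fail. For instance, try calling the <code class="docutils literal"><span class="pre">save_weights()</span></code> method of <code class="docutils literal"><span class="pre">tf.keras.Model</span></code> in train.py to save the parameters of model, and then calling <code class="docutils literal"><span class="pre">load_weights()</span></code> in test.py immediately after instantiating model: an error occurs, and <code class="docutils literal"><span class="pre">load_weights()</span></code> only gives the correct result after the model has been called once. <code class="docutils literal"><span class="pre">tf.train.Checkpoint</span></code> is therefore a considerable convenience in this situation. In addition, <code class="docutils literal"><span class="pre">tf.train.Checkpoint</span></code> also supports Graph Execution mode.</p>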
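<p>The idea of delayed restoration can be illustrated with a toy model in plain Python. <code class="docutils literal"><span class="pre">ToyDeferredRestorer</span></code> and the variable name <code class="docutils literal"><span class="pre">dense/kernel</span></code> are invented for this sketch and do not reflect how TensorFlow implements the feature internally. The sketch only shows the behaviour: a value restored before its variable exists is held back and applied at creation time.</p>

```python
class ToyDeferredRestorer:
    """Toy model of delayed restoration (not TensorFlow's implementation):
    saved values wait until a variable with a matching name is created."""

    def __init__(self):
        self.variables = {}   # name -> current value
        self.pending = {}     # name -> saved value waiting for its variable

    def restore(self, saved_values):
        for name, value in saved_values.items():
            if name in self.variables:
                self.variables[name] = value      # variable exists: restore now
            else:
                self.pending[name] = value        # not built yet: defer

    def create_variable(self, name, initial_value):
        # A pending saved value takes priority over the initializer.
        self.variables[name] = self.pending.pop(name, initial_value)
        return self.variables[name]


restorer = ToyDeferredRestorer()
restorer.restore({"dense/kernel": 0.5})          # called before any variable exists
value = restorer.create_variable("dense/kernel", initial_value=0.0)
print(value)   # 0.5
```

<p>In the toy, calling <code class="docutils literal"><span class="pre">restore()</span></code> first and creating the variable afterwards still yields the saved value, which mirrors restoring into a freshly instantiated model and only then calling it.</p>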
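<p>Finally, here is a complete example that uses the <a class="reference internal" href="models.html#mlp"><span class="std std-ref">multilayer perceptron model</span></a> from the previous chapter to demonstrate saving and loading model variables:</p>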
<div class="highlight-default"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">from</span> <span class="nn">zh.model.mlp.mlp</span> <span class="k">import</span> <span class="n">MLP</span>
<span class="kn">from</span> <span class="nn">zh.model.mlp.utils</span> <span class="k">import</span> <span class="n">DataLoader</span>

<span class="n">tf</span><span class="o">.</span><span class="n">enable_eager_execution</span><span class="p">()</span>
<span class="n">mode</span> <span class="o">=</span> <span class="s1">&#39;train&#39;</span>
<span class="n">num_batches</span> <span class="o">=</span> <span class="mi">1000</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">50</span>
<span class="n">learning_rate</span> <span class="o">=</span> <span class="mf">0.001</span>
<span class="n">data_loader</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">()</span>


<span class="k">def</span> <span class="nf">train</span><span class="p">():</span>
    <span class="n">model</span> <span class="o">=</span> <span class="n">MLP</span><span class="p">()</span>
    <span class="n">optimizer</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">train</span><span class="o">.</span><span class="n">AdamOptimizer</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="n">learning_rate</span><span class="p">)</span>
    <span class="n">checkpoint</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">train</span><span class="o">.</span><span class="n">Checkpoint</span><span class="p">(</span><span class="n">myAwesomeModel</span><span class="o">=</span><span class="n">model</span><span class="p">)</span>      <span class="c1"># instantiate Checkpoint and set model as the object to save</span>
    <span class="k">for</span> <span class="n">batch_index</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_batches</span><span class="p">):</span>
        <span class="n">X</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">data_loader</span><span class="o">.</span><span class="n">get_batch</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span>
        <span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">GradientTape</span><span class="p">()</span> <span class="k">as</span> <span class="n">tape</span><span class="p">:</span>
            <span class="n">y_logit_pred</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">convert_to_tensor</span><span class="p">(</span><span class="n">X</span><span class="p">))</span>
            <span class="n">loss</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">sparse_softmax_cross_entropy</span><span class="p">(</span><span class="n">labels</span><span class="o">=</span><span class="n">y</span><span class="p">,</span> <span class="n">logits</span><span class="o">=</span><span class="n">y_logit_pred</span><span class="p">)</span>
            <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;batch </span><span class="si">%d</span><span class="s2">: loss </span><span class="si">%f</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="p">(</span><span class="n">batch_index</span><span class="p">,</span> <span class="n">loss</span><span class="o">.</span><span class="n">numpy</span><span class="p">()))</span>
        <span class="n">grads</span> <span class="o">=</span> <span class="n">tape</span><span class="o">.</span><span class="n">gradient</span><span class="p">(</span><span class="n">loss</span><span class="p">,</span> <span class="n">model</span><span class="o">.</span><span class="n">variables</span><span class="p">)</span>
        <span class="n">optimizer</span><span class="o">.</span><span class="n">apply_gradients</span><span class="p">(</span><span class="n">grads_and_vars</span><span class="o">=</span><span class="nb">zip</span><span class="p">(</span><span class="n">grads</span><span class="p">,</span> <span class="n">model</span><span class="o">.</span><span class="n">variables</span><span class="p">))</span>
        <span class="k">if</span> <span class="p">(</span><span class="n">batch_index</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="mi">100</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>                        <span class="c1"># save once every 100 batches</span>
            <span class="n">checkpoint</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="s1">&#39;./save/model.ckpt&#39;</span><span class="p">)</span>                <span class="c1"># save the model parameters to file</span>


<span class="k">def</span> <span class="nf">test</span><span class="p">():</span>
    <span class="n">model_to_be_restored</span> <span class="o">=</span> <span class="n">MLP</span><span class="p">()</span>
    <span class="n">checkpoint</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">train</span><span class="o">.</span><span class="n">Checkpoint</span><span class="p">(</span><span class="n">myAwesomeModel</span><span class="o">=</span><span class="n">model_to_be_restored</span><span class="p">)</span>      <span class="c1"># instantiate Checkpoint and set the newly created model_to_be_restored as the object to restore</span>
    <span class="n">checkpoint</span><span class="o">.</span><span class="n">restore</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">train</span><span class="o">.</span><span class="n">latest_checkpoint</span><span class="p">(</span><span class="s1">&#39;./save&#39;</span><span class="p">))</span>    <span class="c1"># restore the model parameters from file</span>
    <span class="n">num_eval_samples</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">data_loader</span><span class="o">.</span><span class="n">eval_labels</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
    <span class="n">y_pred</span> <span class="o">=</span> <span class="n">model_to_be_restored</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">constant</span><span class="p">(</span><span class="n">data_loader</span><span class="o">.</span><span class="n">eval_data</span><span class="p">))</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
    <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;test accuracy: </span><span class="si">%f</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="p">(</span><span class="nb">sum</span><span class="p">(</span><span class="n">y_pred</span> <span class="o">==</span> <span class="n">data_loader</span><span class="o">.</span><span class="n">eval_labels</span><span class="p">)</span> <span class="o">/</span> <span class="n">num_eval_samples</span><span class="p">))</span>


<span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s1">&#39;__main__&#39;</span><span class="p">:</span>
    <span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s1">&#39;train&#39;</span><span class="p">:</span>
        <span class="n">train</span><span class="p">()</span>
    <span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s1">&#39;test&#39;</span><span class="p">:</span>
        <span class="n">test</span><span class="p">()</span>
</pre></div>
</div>
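<p>After creating a save folder in the code directory and running the code to train, the save folder holds the model variables saved once every 100 batches. Change line 7 to <code class="docutils literal"><span class="pre">mode</span> <span class="pre">=</span> <span class="pre">'test'</span></code> and run the code again: the model is restored from the most recently saved variable values and evaluated on the test set, directly reaching an accuracy of about 95%.</p>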
</div>
<div class="section" id="tensorboard">
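<h2>TensorBoard: Visualizing the Training Process<a class="headerlink" href="#tensorboard" title="Permalink to this headline">¶</a></h2>
<p>Sometimes you want to see how various quantities (such as the value of the loss function) change during training. Printing them on the command line works, but is not always intuitive. TensorBoard is a tool that helps visualize the training process.</p>
<p>At present, TensorBoard support under Eager Execution still lives in <a class="reference external" href="https://www.tensorflow.org/api_docs/python/tf/contrib/summary">tf.contrib.summary</a> and may change considerably in the future, so only a simple example is given here. First create a folder in the code directory (e.g. ./tensorboard) to hold the TensorBoard log files, and instantiate a writer in the code:</p>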
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">summary_writer</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">contrib</span><span class="o">.</span><span class="n">summary</span><span class="o">.</span><span class="n">create_file_writer</span><span class="p">(</span><span class="s1">&#39;./tensorboard&#39;</span><span class="p">)</span>
</pre></div>
</div>
<p>Next, place the training code inside the contexts of <code class="docutils literal"><span class="pre">summary_writer.as_default()</span></code> and <code class="docutils literal"><span class="pre">tf.contrib.summary.always_record_summaries()</span></code> using a with statement, and call <code class="docutils literal"><span class="pre">tf.contrib.summary.scalar(name,</span> <span class="pre">tensor,</span> <span class="pre">step=batch_index)</span></code> for each quantity to record (usually a scalar). The step argument can be set as needed; typically it is the index of the current batch in training. The overall skeleton is:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">summary_writer</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">contrib</span><span class="o">.</span><span class="n">summary</span><span class="o">.</span><span class="n">create_file_writer</span><span class="p">(</span><span class="s1">&#39;./tensorboard&#39;</span><span class="p">)</span>
<span class="k">with</span> <span class="n">summary_writer</span><span class="o">.</span><span class="n">as_default</span><span class="p">(),</span> <span class="n">tf</span><span class="o">.</span><span class="n">contrib</span><span class="o">.</span><span class="n">summary</span><span class="o">.</span><span class="n">always_record_summaries</span><span class="p">():</span>
    <span class="c1"># start model training</span>
    <span class="k">for</span> <span class="n">batch_index</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_batches</span><span class="p">):</span>
        <span class="c1"># training code; the loss of the current batch is stored in the variable loss</span>
        <span class="n">tf</span><span class="o">.</span><span class="n">contrib</span><span class="o">.</span><span class="n">summary</span><span class="o">.</span><span class="n">scalar</span><span class="p">(</span><span class="s2">&quot;loss&quot;</span><span class="p">,</span> <span class="n">loss</span><span class="p">,</span> <span class="n">step</span><span class="o">=</span><span class="n">batch_index</span><span class="p">)</span>
        <span class="n">tf</span><span class="o">.</span><span class="n">contrib</span><span class="o">.</span><span class="n">summary</span><span class="o">.</span><span class="n">scalar</span><span class="p">(</span><span class="s2">&quot;MyScalar&quot;</span><span class="p">,</span> <span class="n">my_scalar</span><span class="p">,</span> <span class="n">step</span><span class="o">=</span><span class="n">batch_index</span><span class="p">)</span>  <span class="c1"># other custom variables can be added as well</span>
</pre></div>
</div>
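<p>Each call to <code class="docutils literal"><span class="pre">tf.contrib.summary.scalar()</span></code> writes one record to the log file. Besides simple scalars, TensorBoard can also visualize other types of data (such as images and audio); see the <a class="reference external" href="https://www.tensorflow.org/api_docs/python/tf/contrib/summary">API documentation</a> for details.</p>
<p>To visualize the training process, open a terminal in the code directory (entering TensorFlow's conda environment if necessary) and run:</p>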
<div class="highlight-default"><div class="highlight"><pre><span></span><span class="n">tensorboard</span> <span class="o">--</span><span class="n">logdir</span><span class="o">=./</span><span class="n">tensorboard</span>
</pre></div>
</div>
<p>Then open the URL printed by the command (usually http://&lt;computer-name&gt;:6006) in a browser to reach the TensorBoard interface, as shown below:</p>
<div class="figure align-center">
<a class="reference internal image-reference" href="../_images/tensorboard1.png"><img alt="../_images/tensorboard1.png" src="../_images/tensorboard1.png" style="width: 100%;" /></a>
</div>
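<p>By default, TensorBoard refreshes its data every 30 seconds. You can also click the refresh button in the upper right corner to refresh manually.</p>
<p>Keep the following in mind when using TensorBoard:</p>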
<ul class="simple">
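<li>To retrain, delete the contents of the log folder and restart TensorBoard (or create a new log folder and start TensorBoard with the <code class="docutils literal"><span class="pre">--logdir</span></code> argument set to the new folder);</li>
<li>Keep the path of the log folder in English characters only.</li>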
</ul>
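<p>One way to follow the first note without deleting old records is to give every training run its own fresh folder. The sketch below is only a suggestion: the helper <code class="docutils literal"><span class="pre">new_run_logdir</span></code> and the <code class="docutils literal"><span class="pre">run-YYYYMMDD-HHMMSS</span></code> naming are assumptions made for this example, and a temporary directory stands in for ./tensorboard. The generated names contain ASCII characters only, in line with the second note.</p>

```python
import os
import tempfile
import time


def new_run_logdir(base_dir):
    """Hypothetical helper: create and return a fresh sub-directory of
    base_dir for one training run, named after the current time."""
    run_name = time.strftime("run-%Y%m%d-%H%M%S")
    path = os.path.join(base_dir, run_name)
    suffix = 1
    while os.path.exists(path):               # avoid clashing within one second
        path = os.path.join(base_dir, "%s-%d" % (run_name, suffix))
        suffix += 1
    os.makedirs(path)
    return path


base = tempfile.mkdtemp()          # stands in for ./tensorboard
first = new_run_logdir(base)
second = new_run_logdir(base)
print(first != second)             # True
```

<p>The returned path would then be passed to <code class="docutils literal"><span class="pre">create_file_writer</span></code> and used as the <code class="docutils literal"><span class="pre">--logdir</span></code> argument for that run.</p>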
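<p>Finally, here is a complete example that uses the <a class="reference internal" href="models.html#mlp"><span class="std std-ref">multilayer perceptron model</span></a> from the previous chapter to demonstrate the use of TensorBoard:</p>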
<div class="highlight-default"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">from</span> <span class="nn">zh.model.mlp.mlp</span> <span class="k">import</span> <span class="n">MLP</span>
<span class="kn">from</span> <span class="nn">zh.model.mlp.utils</span> <span class="k">import</span> <span class="n">DataLoader</span>

<span class="n">tf</span><span class="o">.</span><span class="n">enable_eager_execution</span><span class="p">()</span>
<span class="n">num_batches</span> <span class="o">=</span> <span class="mi">10000</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">50</span>
<span class="n">learning_rate</span> <span class="o">=</span> <span class="mf">0.001</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">MLP</span><span class="p">()</span>
<span class="n">data_loader</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">()</span>
<span class="n">optimizer</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">train</span><span class="o">.</span><span class="n">AdamOptimizer</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="n">learning_rate</span><span class="p">)</span>
<span class="n">summary_writer</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">contrib</span><span class="o">.</span><span class="n">summary</span><span class="o">.</span><span class="n">create_file_writer</span><span class="p">(</span><span class="s1">&#39;./tensorboard&#39;</span><span class="p">)</span>     <span class="c1"># instantiate the writer</span>
<span class="k">with</span> <span class="n">summary_writer</span><span class="o">.</span><span class="n">as_default</span><span class="p">(),</span> <span class="n">tf</span><span class="o">.</span><span class="n">contrib</span><span class="o">.</span><span class="n">summary</span><span class="o">.</span><span class="n">always_record_summaries</span><span class="p">():</span>
    <span class="k">for</span> <span class="n">batch_index</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_batches</span><span class="p">):</span>
        <span class="n">X</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">data_loader</span><span class="o">.</span><span class="n">get_batch</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span>
        <span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">GradientTape</span><span class="p">()</span> <span class="k">as</span> <span class="n">tape</span><span class="p">:</span>
            <span class="n">y_logit_pred</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">convert_to_tensor</span><span class="p">(</span><span class="n">X</span><span class="p">))</span>
            <span class="n">loss</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">sparse_softmax_cross_entropy</span><span class="p">(</span><span class="n">labels</span><span class="o">=</span><span class="n">y</span><span class="p">,</span> <span class="n">logits</span><span class="o">=</span><span class="n">y_logit_pred</span><span class="p">)</span>
            <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;batch </span><span class="si">%d</span><span class="s2">: loss </span><span class="si">%f</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="p">(</span><span class="n">batch_index</span><span class="p">,</span> <span class="n">loss</span><span class="o">.</span><span class="n">numpy</span><span class="p">()))</span>
            <span class="n">tf</span><span class="o">.</span><span class="n">contrib</span><span class="o">.</span><span class="n">summary</span><span class="o">.</span><span class="n">scalar</span><span class="p">(</span><span class="s2">&quot;loss&quot;</span><span class="p">,</span> <span class="n">loss</span><span class="p">,</span> <span class="n">step</span><span class="o">=</span><span class="n">batch_index</span><span class="p">)</span>       <span class="c1"># record the current loss</span>
        <span class="n">grads</span> <span class="o">=</span> <span class="n">tape</span><span class="o">.</span><span class="n">gradient</span><span class="p">(</span><span class="n">loss</span><span class="p">,</span> <span class="n">model</span><span class="o">.</span><span class="n">variables</span><span class="p">)</span>
        <span class="n">optimizer</span><span class="o">.</span><span class="n">apply_gradients</span><span class="p">(</span><span class="n">grads_and_vars</span><span class="o">=</span><span class="nb">zip</span><span class="p">(</span><span class="n">grads</span><span class="p">,</span> <span class="n">model</span><span class="o">.</span><span class="n">variables</span><span class="p">))</span>
</pre></div>
</div>
</div>
<div class="section" id="gpu">
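<h2>GPU usage and allocation<a class="headerlink" href="#gpu" title="Permalink to this headline">¶</a></h2>
<p>A common situation: many students or researchers in a lab or company research group need GPUs, but there is only one multi-GPU machine, so GPU resources have to be allocated with care.</p>
<p>The <code class="docutils literal"><span class="pre">nvidia-smi</span></code> command shows the GPUs present on the machine and their current usage (on Windows, add <code class="docutils literal"><span class="pre">C:\Program</span> <span class="pre">Files\NVIDIA</span> <span class="pre">Corporation\NVSMI</span></code> to the Path environment variable; on Windows 10 you can also view GPU information in the "Performance" tab of Task Manager).</p>
<p>The environment variable <code class="docutils literal"><span class="pre">CUDA_VISIBLE_DEVICES</span></code> controls which GPUs a program uses. Suppose that on a four-GPU machine GPUs 0 and 1 are in use while GPUs 2 and 3 are idle. In a Linux terminal, enter:</p>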
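<p>As a minimal sketch of how this check could be automated, the snippet below parses the CSV text printed by <code class="docutils literal"><span class="pre">nvidia-smi</span> <span class="pre">--query-gpu=index,memory.used</span> <span class="pre">--format=csv,noheader,nounits</span></code> and lists the GPUs whose used memory is at or below a threshold. The helper names <code class="docutils literal"><span class="pre">find_idle_gpus</span></code> and <code class="docutils literal"><span class="pre">query_idle_gpus</span></code>, the 100 MiB threshold, and the sample text are assumptions made for illustration; they are not part of the original tutorial.</p>

```python
import subprocess

# Query command: one CSV line per GPU, "index, used memory in MiB".
QUERY_CMD = ["nvidia-smi", "--query-gpu=index,memory.used",
             "--format=csv,noheader,nounits"]


def find_idle_gpus(csv_text, max_used_mib=100):
    """Return indices of GPUs using at most max_used_mib MiB (hypothetical helper)."""
    idle = []
    for line in csv_text.strip().splitlines():
        index, used = [field.strip() for field in line.split(",")]
        if max_used_mib >= int(used):
            idle.append(int(index))
    return idle


def query_idle_gpus(max_used_mib=100):
    """Run nvidia-smi and report idle GPUs; needs an NVIDIA driver installed."""
    output = subprocess.check_output(QUERY_CMD).decode("utf-8")
    return find_idle_gpus(output, max_used_mib)


# Illustrative sample text (not real measurements): GPUs 0 and 1 busy, 2 and 3 idle.
sample = "0, 10240\n1, 8123\n2, 3\n3, 0\n"
idle_gpus = find_idle_gpus(sample)
print(idle_gpus)
```

<p>On a machine with an NVIDIA driver, <code class="docutils literal"><span class="pre">query_idle_gpus()</span></code> would run the real command instead of using the sample text. Low memory use is only a rough proxy for "idle", so it is still worth confirming with the other users of the machine.</p>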
<div class="highlight-default"><div class="highlight"><pre><span></span><span class="n">export</span> <span class="n">CUDA_VISIBLE_DEVICES</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span><span class="mi">3</span>
</pre></div>
</div>
<p>or add the following to your code:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">os</span>
<span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s1">&#39;CUDA_VISIBLE_DEVICES&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="s2">&quot;2,3&quot;</span>
</pre></div>
</div>
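<p>Either way, the program is restricted to running on GPUs 2 and 3.</p>
<p>By default, TensorFlow claims nearly all available GPU memory in order to avoid the performance loss caused by memory fragmentation. The <code class="docutils literal"><span class="pre">tf.ConfigProto</span></code> class sets TensorFlow's GPU memory strategy: instantiate a <code class="docutils literal"><span class="pre">tf.ConfigProto</span></code>, set its options, and pass it as the config argument when calling <code class="docutils literal"><span class="pre">tf.enable_eager_execution()</span></code>. The following code uses the <code class="docutils literal"><span class="pre">allow_growth</span></code> option so that TensorFlow requests GPU memory only as it is needed:</p>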
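<p>The in-code variant can be wrapped in a small helper, sketched below. The name <code class="docutils literal"><span class="pre">restrict_to_gpus</span></code> is hypothetical. Two points are worth keeping in mind: the variable should be set before TensorFlow is imported, so that it is in place when CUDA is first initialized, and inside the process the visible GPUs are renumbered starting from 0.</p>

```python
import os


def restrict_to_gpus(gpu_ids):
    """Expose only the given physical GPUs to this process (hypothetical helper)."""
    value = ",".join(str(i) for i in gpu_ids)
    os.environ["CUDA_VISIBLE_DEVICES"] = value
    return value


# Call this before importing TensorFlow, so the setting is in place
# when the CUDA runtime is first initialized.
visible = restrict_to_gpus([2, 3])
print(visible)

# Inside the process the visible devices are renumbered from 0:
# physical GPU 2 appears as device 0, and physical GPU 3 as device 1.
logical_to_physical = dict(enumerate([2, 3]))
print(logical_to_physical)
```

<p>With this mapping, a device string such as <code class="docutils literal"><span class="pre">/gpu:0</span></code> in the program refers to physical GPU 2.</p>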
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">config</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">ConfigProto</span><span class="p">()</span>
<span class="n">config</span><span class="o">.</span><span class="n">gpu_options</span><span class="o">.</span><span class="n">allow_growth</span> <span class="o">=</span> <span class="bp">True</span>
<span class="n">tf</span><span class="o">.</span><span class="n">enable_eager_execution</span><span class="p">(</span><span class="n">config</span><span class="o">=</span><span class="n">config</span><span class="p">)</span>
</pre></div>
</div>
<p>The following code uses the <code class="docutils literal"><span class="pre">per_process_gpu_memory_fraction</span></code> option so that TensorFlow takes a fixed 40% of GPU memory:</p>
<div class="highlight-python"><div class="highlight"><pre><span></span><span class="n">config</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">ConfigProto</span><span class="p">()</span>
<span class="n">config</span><span class="o">.</span><span class="n">gpu_options</span><span class="o">.</span><span class="n">per_process_gpu_memory_fraction</span> <span class="o">=</span> <span class="mf">0.4</span>
<span class="n">tf</span><span class="o">.</span><span class="n">enable_eager_execution</span><span class="p">(</span><span class="n">config</span><span class="o">=</span><span class="n">config</span><span class="p">)</span>
</pre></div>
</div>
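<p>To make the 40% figure concrete, here is a small arithmetic sketch. The card size of 11264 MiB (11 GiB) is an assumed example value and the helper <code class="docutils literal"><span class="pre">memory_budget</span></code> is hypothetical. The calculation ignores driver and CUDA context overhead, so the memory actually usable per process will be somewhat lower.</p>

```python
def memory_budget(total_mib, fraction):
    """Return (MiB per process, how many such processes fit on one GPU)."""
    per_process_mib = total_mib * fraction
    max_processes = int(total_mib // per_process_mib)
    return per_process_mib, max_processes


# Assumed example: a card with 11264 MiB (11 GiB) of memory.
per_process, max_procs = memory_budget(total_mib=11264, fraction=0.4)
print(per_process, max_procs)
```

<p>Under these assumptions, a fraction of 0.4 leaves room for two such processes per card, with roughly 20% of the memory left unassigned.</p>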
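<p>Under Graph Execution, the same settings can be applied by passing a <code class="docutils literal"><span class="pre">tf.ConfigProto</span></code> instance when creating a new session.</p>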
</div>
</div>


           </div>
           <div class="articleComments">
            
           </div>
          </div>
          <footer>
  
    <div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
      
        <a href="static.html" class="btn btn-neutral float-right" title="Appendix: Static TensorFlow" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right"></span></a>
      
      
        <a href="models.html" class="btn btn-neutral" title="TensorFlow Models" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left"></span> Previous</a>
      
    </div>
  

  <hr/>

  <div role="contentinfo">
    <p>
        &copy; Copyright 2018, Xihan Li（雪麒）.

    </p>
  </div>
  Built with <a href="http://sphinx-doc.org/">Sphinx</a> using a <a href="https://github.com/snide/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>. 

</footer>

        </div>
      </div>

    </section>

  </div>
  


  

    <script type="text/javascript">
        var DOCUMENTATION_OPTIONS = {
            URL_ROOT:'../',
            VERSION:'0.3 beta',
            COLLAPSE_INDEX:false,
            FILE_SUFFIX:'.html',
            HAS_SOURCE:  true,
            SOURCELINK_SUFFIX: '.txt'
        };
    </script>
      <script type="text/javascript" src="../_static/jquery.js"></script>
      <script type="text/javascript" src="../_static/underscore.js"></script>
      <script type="text/javascript" src="../_static/doctools.js"></script>
      <script type="text/javascript" src="../_static/translations.js"></script>

  

  
  
    <script type="text/javascript" src="../_static/js/theme.js"></script>
  

  
  
  <script type="text/javascript">
      jQuery(function () {
          SphinxRtdTheme.StickyNav.enable();
      });
  </script>
   

</body>
</html>